import os
import numpy
from jqc import jqc_plot
from scipy import constants
from matplotlib import pyplot
from diatom import Hamiltonian
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch,Rectangle
from matplotlib.collections import LineCollection
from sympy.physics.wigner import wigner_3j,wigner_9j
from matplotlib.colors import LogNorm,LinearSegmentedColormap
from matplotlib.ticker import (
    AutoLocator, AutoMinorLocator)

def make_segments(x, y):
    '''
    Build the segment array expected by LineCollection from the coordinate
    sequences x and y.

    Neighbouring points are paired up, so the result is an array of the form
    (number of segments) x 2 (start and end point) x 2 (x and y).
    '''
    # One (x, y) pair per point, with a length-one middle axis so that
    # neighbouring points can be joined along it.
    vertices = numpy.transpose(numpy.array([x, y])).reshape(-1, 1, 2)

    # Segment k runs from point k to point k+1.
    starts, stops = vertices[:-1], vertices[1:]
    return numpy.hstack((starts, stops))

def colorline(x, y, z=None, cmap=pyplot.get_cmap('copper'),
                norm=pyplot.Normalize(0.0, 1.0), linewidth=3, alpha=1.0,
                legend=False,ax=None):
    '''
    Plot a line through the points (x, y) whose colour varies along its length.

    Parameters
    ----------
    x, y : sequences of coordinates; consecutive points are joined into
        segments by make_segments.
    z : values handed to the LineCollection as its colour array. A single
        number is wrapped into a one-element array. If None, values equally
        spaced on [0, 1] (one per point of x) are used.
    cmap, norm : colormap and normalisation passed to the LineCollection.
        The default objects are created once, when the function is defined,
        and are therefore shared by every call that relies on them.
    linewidth : width of the drawn segments.
    alpha, legend : accepted for compatibility but not used by this function.
        alpha is deliberately not forwarded: a collection-level alpha would
        replace any alpha channel defined by the colormap itself.
    ax : axes to draw on; the current axes are used if None.

    Returns
    -------
    The LineCollection that was added to ax (usable as a colorbar mappable).
    '''
    # Identity test: "== None" would invoke the object's own __eq__.
    if ax is None:
        ax = pyplot.gca()

    # Default colors equally spaced on [0,1]:
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))

    # Plain numbers (and 0-d arrays) become one-element arrays; anything
    # already array-like is converted unchanged.
    z = numpy.atleast_1d(numpy.asarray(z))

    segments = make_segments(x, y)
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm,
                        linewidth=linewidth,zorder=1.25)

    ax.add_collection(lc)

    return lc

def dipolez(Nmax,d):
    '''
    Generates the dipole moment matrix for a Rigid rotor.

    The basis runs over N = 0..Nmax and, within each N, over M = +N..-N in
    descending order, giving a square matrix of side (Nmax+1)**2.
    Each element is

        d * sqrt((2*N1+1)*(2*N2+1)) * (-1)**M1
          * wigner_3j(N1,1,N2,-M1,0,M2) * wigner_3j(N1,1,N2,0,0,0)

    i.e. the component-0 matrix element of a rank-1 operator as encoded by
    the two 3j symbols, scaled by d.

    Parameters
    ----------
    Nmax : largest value of N included in the basis.
    d : overall scale factor multiplying every element.

    Returns
    -------
    Complex numpy array holding the matrix, rows labelled by (N1, M1) and
    columns by (N2, M2).
    '''
    # Explicit list of basis labels; its order fixes the row/column order.
    states = [(N, M) for N in range(0,Nmax+1) for M in range(N,-(N+1),-1)]
    size = len(states)

    # numpy.complex was only an alias of the builtin complex and has been
    # removed from NumPy; complex128 is the dtype it used to select.
    Dmat = numpy.zeros((size,size),dtype=numpy.complex128)

    for i,(N1,M1) in enumerate(states):
        for j,(N2,M2) in enumerate(states):
            Dmat[i,j]=d*numpy.sqrt((2*N1+1)*(2*N2+1))*(-1)**(M1)*\
            wigner_3j(N1,1,N2,-M1,0,M2)*wigner_3j(N1,1,N2,0,0,0)
    return Dmat

#set up environment
jqc_plot.plot_style("normal")
# 2x2 layout: two stacked panels in a wide left column and a narrow right
# column (5% of the width) that is later used for the colour bar.
grid = gridspec.GridSpec(2,2,width_ratios=[1,0.05],height_ratios=[1.5,1])
# Directory containing this script, and its parent directory.
cwd = os.path.dirname(os.path.abspath(__file__))
root = os.path.dirname(cwd)

fig = pyplot.figure("EVERYTHING")

#set some constants

# Colour ramp from RGB (244,234,168)/255 at 0.0 to RGB (0,70,127)/255 at 0.6,
# constant from there up to 1.0.
colour_dict_twk_blue = {
    "red" : [(0.0,244/255,244/255),
            (0.6,0,0),
            (1.0,0,0)] ,
    "green" : [(0.0,234/255,234/255),
            (0.6,70/255.0,70/255.0),
            (1.0,70/255,70/255)],
    "blue" : [(0.0,168/255,168/255),
            (0.6,127/255,127/255),
            (1.0,127/255,127/255)]
}
# Same ramp with an opacity channel: fully transparent at 0.0, half opaque at
# 0.25 and fully opaque from 0.5 upwards.
colour_dict_twk_blue_alpha = colour_dict_twk_blue.copy()
colour_dict_twk_blue_alpha['alpha'] = ((0.0, 0.0,0.0),
                   (0.25, .5, .5),
                   (0.5, 1., 1.),
                   (1.0, 1.0, 1.0))


RbCs_map_twk_blue = LinearSegmentedColormap("RbCs_map_tweak_blue",
                                                colour_dict_twk_blue_alpha)

# Register the colormap so that it can be requested by name later on.
# pyplot.register_cmap has been deprecated and removed in recent Matplotlib
# releases; the colormap registry is used when it exists and the old call is
# kept as a fallback for Matplotlib versions that predate the registry.
try:
    from matplotlib import colormaps as _colormap_registry
except ImportError:
    _colormap_registry = None

if _colormap_registry is None:
    pyplot.register_cmap(cmap=RbCs_map_twk_blue)
else:
    # force=True lets the script be re-run in the same interpreter session
    # without failing on the already-registered name.
    _colormap_registry.register(RbCs_map_twk_blue, force=True)

# Largest N in the basis and the two spin values used throughout the script.
Nmax = 5
I1, I2 = 3/2, 7/2

# Unit vectors of length 2*I+1 with only their first entry set, and their
# Kronecker product.
I1_state = numpy.zeros(int(2*I1+1))
I2_state = numpy.zeros(int(2*I2+1))
I1_state[0] = 1
I2_state[0] = 1
I_State = numpy.kron(I1_state, I2_state)

# Shorthands for Planck's constant, the molecular constants and the palette.
h = constants.h
RbCs = Hamiltonian.RbCs
colours = jqc_plot.colours

# Operator sets from the diatom package; F is their sum and Fz its third
# component.
Nvec, I1vec, I2vec = Hamiltonian.Generate_vecs(Nmax, I1, I2)
F = Nvec + I1vec + I2vec
Fz = F[2]

# Labels [N, MN, MI1, MI2], each projection running downwards from its
# maximum value, with MI2 varying fastest and N slowest.
indices = [[N, MN, MI1, MI2]
           for N in range(0, Nmax+1)
           for MN in range(N, -(N+1), -1)
           for MI1 in numpy.arange(I1, -(I1+1), -1)
           for MI2 in numpy.arange(I2, -(I2+1), -1)]

# Upper panel (ax_Stark) and lower panel (ax_dipole) share their x axis; the
# lower panel uses a logarithmic y axis and only it shows x tick labels.
ax_Stark = fig.add_subplot(grid[0,0])
ax_dipole = fig.add_subplot(grid[1,0],sharex=ax_Stark)


ax_dipole.set_yscale('log')
ax_Stark.tick_params(labelbottom=False)

# Paths are assembled with os.path.join so that the script also finds its
# data on platforms whose separator is not a backslash. fpath keeps a
# trailing separator, as it had before.
fpath = os.path.join(cwd,"Data","")

fname = "Fig2_DC.csv"

# Files used only when the pre-computed tables below cannot be read.
states_path = os.path.join(fpath,"Sorted","multithread_N5_states_2kV.npy")
tdm_cache_path = os.path.join(fpath,"Sorted","Ncalc5_2kV_N1_TDMz.csv")
mF_cache_path = os.path.join(fpath,"Sorted","Ncalc5_2kV_mF.csv")

Hyperfine_energy = numpy.genfromtxt(os.path.join(fpath,"Energies",fname),
                                    delimiter=',')

# First row: field values (rescaled by 1e-2; presumably V/m -> V/cm, confirm
# against the data file). Remaining rows: one energy curve per row.
E = 1e-2*Hyperfine_energy[0,:]
Hyperfine_energy = Hyperfine_energy[1:,:]

# Colour counter used by the plotting loop further down.
k=0

# Number of rows to plot: 32*(2N+1) summed over N = 0..Nplotmax, where
# 32 = (2*I1+1)*(2*I2+1) for I1 = 3/2 and I2 = 7/2.
Nplotmax = 1
numberplot= numpy.sum([(2*x+1)*32 for x in range(Nplotmax+1)])

# NOTE(review): both fallbacks below write their results under "Sorted" with
# names that differ from the "TDM"/"MF" files read in the try branches, so a
# later run will not pick up what is saved here -- confirm this is intended.
try:
    d = numpy.genfromtxt(os.path.join(fpath,"TDM",fname),delimiter=',',
                        dtype=numpy.complex128)
    print("loaded TDM")

except IOError:
    Hyperfine_States = numpy.load(states_path)
    print("Calculating dipole moments")
    # Dipole matrix in the rotational basis, extended by identities of size
    # 2*I1+1 and 2*I2+1.
    dz = dipolez(Nmax,1)
    dz = numpy.kron(dz,numpy.kron(numpy.identity(int(2*I1+1)),
                    numpy.identity(int(2*I2+1))))
    # For every last-axis index x and every selected state k:
    # sum_ij States[i,0,x] * dz[i,j] * States[j,k,x]  (no conjugation applied).
    # NOTE(review): the slice stops at numberplot+1, one state beyond the rows
    # the plotting loop reads (d[i-32] for i < numberplot) -- confirm.
    d = numpy.einsum('ix,ij,jkx->kx',
    Hyperfine_States[:,0,:],dz,Hyperfine_States[:,32:numberplot+1,:])
    numpy.savetxt(tdm_cache_path,d,delimiter=',')
    print("Saved dipole moments to:"+tdm_cache_path)

try:
    mF = numpy.genfromtxt(os.path.join(fpath,"MF",fname),delimiter=',')
    print("loaded mF")

except IOError:
    Hyperfine_States = numpy.load(states_path)
    print("Calculating mF")
    Nvec,I1vec,I2vec = Hamiltonian.Generate_vecs(Nmax,I1,I2)
    F = Nvec+I1vec+I2vec
    Fz=F[2]
    # Rounded real part of sum_ij States[i,k,1] * Fz[i,j] * States[j,k,1],
    # evaluated at index 1 of the last axis of the state array.
    mF = numpy.round(numpy.einsum('ik,ij,jk->k',
            Hyperfine_States[:,:,1],Fz,Hyperfine_States[:,:,1]).real)
    numpy.savetxt(mF_cache_path,mF,delimiter=',')
    print("Saved mF to:"+mF_cache_path)
# Line colours for the lower panel, picked in order through the counter k.
colours_fixed = [colours[name]
                for name in ('red','grayblue','green','purple')]

for i in range(numberplot):
    index = indices[i]
    if index[0]==0:
        # Rows labelled N=0: only the one with mF == 5 matters. It is recorded
        # as gs and advances the colour counter.
        if mF[i] ==5:
            k+=1
            gs = i
            print(gs)
    elif index[0] ==1:
        # Rows labelled N=1: d holds these rows starting at offset 32.
        TDM = d[i-32,:]
        x_field = 1e-3*E
        y_energy = 1e-6*Hyperfine_energy[i,:]/h-980

        # Faint background line underneath the colour-coded one.
        ax_Stark.plot(x_field,y_energy,
                    color=colours['sand'],alpha=0.5,zorder=1.0)

        # Same curve again, coloured by 3*|TDM|**2 on a logarithmic scale.
        cl = colorline(x_field,y_energy,3*numpy.abs(TDM)**2,
                        cmap='RbCs_map_tweak_blue',norm=LogNorm(1e-2,1.0),
                        linewidth=2.0,ax=ax_Stark)
        if mF[i] ==+5:
            ax_dipole.plot(x_field,numpy.abs(TDM),color=colours_fixed[k])
            k+=1

# Extent of the rectangle outlined on the upper panel.
xlims =(0, 0.125)
ylims = (-1.05,0.55)

# Decade major ticks on the logarithmic axis, with minor ticks at every
# integer multiple 2..9 of each decade (the 4e-4 entry used to be missing).
ax_dipole.set_yticks([1e-4,1e-3,1e-2,1e-1,1])
ax_dipole.set_yticks([2e-4,3e-4,4e-4,5e-4,6e-4,7e-4,8e-4,9e-4,
                    2e-3,3e-3,4e-3,5e-3,6e-3,7e-3,8e-3,9e-3,
                    2e-2,3e-2,4e-2,5e-2,6e-2,7e-2,8e-2,9e-2,
                    2e-1,3e-1,4e-1,5e-1,6e-1,7e-1,8e-1,9e-1],
                    minor=True)

# Scaling by sqrt(3) and its inverse; not used in the rest of this script.
fwd = lambda x: x*numpy.sqrt(3)
rev = lambda x: x/numpy.sqrt(3)

# Unfilled black outline spanning xlims x ylims, drawn above the data.
rect = Rectangle((xlims[0],ylims[0]),(xlims[1]-xlims[0]),(ylims[1]-ylims[0]),
                zorder=5,fill=False,edgecolor='k',lw=1.5)

ax_Stark.add_patch(rect)
ax_Stark.set_xlim(0,1)
ax_Stark.set_ylim(-12.5,50)
ax_Stark.set_ylabel("Energy$/h$ (MHz)")

# The plotted energies have 980 subtracted; the offset is printed just above
# the top-left corner of the panel.
ax_Stark.text(0.01,1.03,"+980 MHz",fontsize=15,clip_on=False,
                transform=ax_Stark.transAxes)

props={'arrowstyle':'->'}

ax_Stark.annotate("$M_N=0$",(0.5,20),(0.25,35),arrowprops=props)
ax_Stark.annotate("$M_N=\\pm1$",(0.5,-8),(0.65,5),arrowprops=props)

ax_dipole.set_xlabel("Electric Field, $E_z$ (kV$\\,$cm$^{-1}$)")
# "\\," keeps the label text unchanged while avoiding the invalid "\," escape
# sequence in a non-raw string.
ax_dipole.set_ylabel("TDM, $|\\mu^{z}_{0i}|\\,(\\mu_0)$")

ax_dipole.set_ylim(2e-4,1)

# Labels "2", "1" and "0" attached to the curves of the lower panel by dotted
# leader lines.
ax_dipole.annotate("2",xy=(.390,1/numpy.sqrt(3)),xytext=(.450,0.1),xycoords="data",
                arrowprops={'ls':':','lw':1,'arrowstyle':'-'},fontsize=15)

ax_dipole.annotate("1",xy=(.195,0.01),xytext=(.350,0.01),xycoords="data",
                arrowprops={'ls':':','lw':1,'arrowstyle':'-'},fontsize=15)

ax_dipole.annotate("0",xy=(.230,0.005),xytext=(.150,0.002),xycoords="data",
                arrowprops={'ls':':','lw':1,'arrowstyle':'-'},fontsize=15)

# Colour bar in the narrow right-hand column, spanning both rows; it uses the
# last collection returned by colorline as its mappable.
colax=fig.add_subplot(grid[:,1])
colax.set_title("$z$",fontsize=15,color=colours['blue'])
fig.colorbar(cl,cax = colax)
colax.set_ylabel("Relative Transition Strength")

ax_Stark.text(0.01,0.87,"(a)",transform=ax_Stark.transAxes,fontsize=20)
ax_dipole.text(0.01,0.09,"(b)",transform=ax_dipole.transAxes,fontsize=20)


pyplot.subplots_adjust(hspace=0.15,wspace=0.07,top=0.94,left=0.15,bottom=0.15,
                        right=0.85)

# Output files are written relative to the current working directory.
pyplot.savefig("fig2.pdf")
pyplot.savefig("fig2.png")
pyplot.show()
